"""Runner for off-policy HARL algorithms."""
import torch
import numpy as np
import torch.nn.functional as F
from mafis.runners.off_policy_base_runner import OffPolicyBaseRunner
from mafis.utils.envs_tools import check


class OffPolicyHARunner(OffPolicyBaseRunner):
    """Runner for off-policy HA algorithms."""

    def train(self):
        """Train the model"""
        self.total_it += 1
        data = self.demo_buffer.sample()
        (
            sp_share_obs,  # EP: (batch_size, dim), FP: (n_agents * batch_size, dim)
            sp_obs,  # (n_agents, batch_size, dim)
            sp_actions,  # (n_agents, batch_size, dim)
            sp_available_actions,  # (n_agents, batch_size, dim)
            sp_reward,  # EP: (batch_size, 1), FP: (n_agents * batch_size, 1)
            sp_done,  # EP: (batch_size, 1), FP: (n_agents * batch_size, 1)
            sp_valid_transition,  # (n_agents, batch_size, 1)
            sp_term,  # EP: (batch_size, 1), FP: (n_agents * batch_size, 1)
            sp_next_share_obs,  # EP: (batch_size, dim), FP: (n_agents * batch_size, dim)
            sp_next_obs,  # (n_agents, batch_size, dim)
            sp_next_available_actions,  # (n_agents, batch_size, dim)
            sp_gamma,  # EP: (batch_size, 1), FP: (n_agents * batch_size, 1)
        ) = data
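        # Shape convention (following HARL): "EP" denotes a single
        # environment-provided global state per transition, while "FP" denotes
        # agent-specific (feature-pruned) global states, in which case the
        # batch axis is repeated n_agents times.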
        # train critic -- disabled in this variant; the original update is
        # kept below for reference
        # self.critic.turn_on_grad()
        # if self.args["algo"] == "hasac":
        #     next_actions = []
        #     next_logp_actions = []
        #     for agent_id in range(self.num_agents):
        #         next_action, next_logp_action = self.actor[
        #             agent_id
        #         ].get_actions_with_logprobs(
        #             sp_next_obs[agent_id],
        #             sp_next_available_actions[agent_id]
        #             if sp_next_available_actions is not None
        #             else None,
        #         )
        #         next_actions.append(next_action)
        #         next_logp_actions.append(next_logp_action)
        #     self.critic.train(
        #         sp_share_obs,
        #         sp_actions,
        #         sp_reward,
        #         sp_done,
        #         sp_valid_transition,
        #         sp_term,
        #         sp_next_share_obs,
        #         next_actions,
        #         next_logp_actions,
        #         sp_gamma,
        #         self.value_normalizer,
        #     )
        # else:
        #     next_actions = []
        #     for agent_id in range(self.num_agents):
        #         next_actions.append(
        #             self.actor[agent_id].get_target_actions(sp_next_obs[agent_id])
        #         )
        #     self.critic.train(
        #         sp_share_obs,
        #         sp_actions,
        #         sp_reward,
        #         sp_done,
        #         sp_term,
        #         sp_next_share_obs,
        #         next_actions,
        #         sp_gamma,
        #     )
        # self.critic.turn_off_grad()
        # retained for the commented-out policy-active-mask losses
        sp_valid_transition = torch.tensor(sp_valid_transition, device=self.device)
        if self.total_it % self.policy_freq == 0:
            # train actors
            if self.args["algo"] == "hasac":
                # actions = []
                # logp_actions = []
                # with torch.no_grad():
                #     for agent_id in range(self.num_agents):
                #         action, logp_action = self.actor[
                #             agent_id
                #         ].get_actions_with_logprobs(
                #             sp_obs[agent_id],
                #             sp_available_actions[agent_id]
                #             if sp_available_actions is not None
                #             else None,
                #         )
                #         actions.append(action)
                #         logp_actions.append(logp_action)
                # actions shape: (n_agents, batch_size, dim)
                # logp_actions shape: (n_agents, batch_size, 1)
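                # agents are updated sequentially; shuffling the order each
                # step follows the heterogeneous-agent sequential-update scheme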
                if self.fixed_order:
                    agent_order = list(range(self.num_agents))
                else:
                    agent_order = list(np.random.permutation(self.num_agents))
                for agent_id in agent_order:
                    self.actor[agent_id].turn_on_grad()
                    # train this agent by behavior cloning: regress the policy
                    # action onto the demonstrated action
                    a, _ = self.actor[agent_id].get_actions_with_logprobs(
                        sp_obs[agent_id],
                        sp_available_actions[agent_id]
                        if sp_available_actions is not None
                        else None,
                    )
                    # use the runner's device instead of a hard-coded .cuda()
                    # so CPU-only runs also work
                    actor_loss = F.mse_loss(
                        a, check(sp_actions[agent_id]).to(self.device)
                    )
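                    # original HASAC value-based actor loss, disabled in favor
                    # of the behavior-cloning loss above: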
                    # if self.state_type == "EP":
                    #     logp_action = logp_actions[agent_id]
                    #     actions_t = torch.cat(actions, dim=-1)
                    # elif self.state_type == "FP":
                    #     logp_action = torch.tile(
                    #         logp_actions[agent_id], (self.num_agents, 1)
                    #     )
                    #     actions_t = torch.tile(
                    #         torch.cat(actions, dim=-1), (self.num_agents, 1)
                    #     )
                    # value_pred = self.critic.get_values(sp_share_obs, actions_t)
                    # if self.algo_args["algo"]["use_policy_active_masks"]:
                    #     if self.state_type == "EP":
                    #         actor_loss = (
                    #             -torch.sum(
                    #                 (value_pred - self.alpha[agent_id] * logp_action)
                    #                 * sp_valid_transition[agent_id]
                    #             )
                    #             / sp_valid_transition[agent_id].sum()
                    #         )
                    #     elif self.state_type == "FP":
                    #         valid_transition = torch.tile(
                    #             sp_valid_transition[agent_id], (self.num_agents, 1)
                    #         )
                    #         actor_loss = (
                    #             -torch.sum(
                    #                 (value_pred - self.alpha[agent_id] * logp_action)
                    #                 * valid_transition
                    #             )
                    #             / valid_transition.sum()
                    #         )
                    # else:
                    #     actor_loss = -torch.mean(
                    #         value_pred - self.alpha[agent_id] * logp_action
                    #     )
                    self.actor[agent_id].actor_optimizer.zero_grad()
                    actor_loss.backward()
                    self.actor[agent_id].actor_optimizer.step()
                    self.actor[agent_id].turn_off_grad()
                    # train this agent's alpha (disabled, kept for reference)
                #     if self.algo_args["algo"]["auto_alpha"]:
                #         log_prob = (
                #             logp_actions[agent_id].detach()
                #             + self.target_entropy[agent_id]
                #         )
                #         alpha_loss = -(self.log_alpha[agent_id] * log_prob).mean()
                #         self.alpha_optimizer[agent_id].zero_grad()
                #         alpha_loss.backward()
                #         self.alpha_optimizer[agent_id].step()
                #         self.alpha[agent_id] = torch.exp(
                #             self.log_alpha[agent_id].detach()
                #         )
                #     actions[agent_id], _ = self.actor[
                #         agent_id
                #     ].get_actions_with_logprobs(
                #         sp_obs[agent_id],
                #         sp_available_actions[agent_id]
                #         if sp_available_actions is not None
                #         else None,
                #     )
                # # train critic's alpha
                # if self.algo_args["algo"]["auto_alpha"]:
                #     self.critic.update_alpha(logp_actions, np.sum(self.target_entropy))
            else:
                if self.args["algo"] == "had3qn":
                    actions = []
                    with torch.no_grad():
                        for agent_id in range(self.num_agents):
                            actions.append(
                                self.actor[agent_id].get_actions(
                                    sp_obs[agent_id], False
                                )
                            )
                    # actions shape: (n_agents, batch_size, 1)
                    update_actions, get_values = self.critic.train_values(
                        sp_share_obs, actions
                    )
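                    # train_values returns two closures: get_values()
                    # re-evaluates the critic under the current joint action,
                    # and update_actions(agent_id) refreshes that agent's
                    # stored action after its actor step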
                    if self.fixed_order:
                        agent_order = list(range(self.num_agents))
                    else:
                        agent_order = list(np.random.permutation(self.num_agents))
                    for agent_id in agent_order:
                        self.actor[agent_id].turn_on_grad()
                        # actor preds
                        actor_values = self.actor[agent_id].train_values(
                            sp_obs[agent_id], actions[agent_id]
                        )
                        # critic preds
                        critic_values = get_values()
                        # update
                        # F.mse_loss already reduces to a scalar mean
                        actor_loss = F.mse_loss(actor_values, critic_values)
                        self.actor[agent_id].actor_optimizer.zero_grad()
                        actor_loss.backward()
                        self.actor[agent_id].actor_optimizer.step()
                        self.actor[agent_id].turn_off_grad()
                        update_actions(agent_id)
                else:
                    actions = []
                    with torch.no_grad():
                        for agent_id in range(self.num_agents):
                            actions.append(
                                self.actor[agent_id].get_actions(
                                    sp_obs[agent_id], False
                                )
                            )
                    # actions shape: (n_agents, batch_size, dim)
                    if self.fixed_order:
                        agent_order = list(range(self.num_agents))
                    else:
                        agent_order = list(np.random.permutation(self.num_agents))
                    for agent_id in agent_order:
                        self.actor[agent_id].turn_on_grad()
                        # train this agent: recompute its action with gradients
                        # enabled and ascend the joint-action value
                        actions[agent_id] = self.actor[agent_id].get_actions(
                            sp_obs[agent_id], False
                        )
                        actions_t = torch.cat(actions, dim=-1)
                        value_pred = self.critic.get_values(sp_share_obs, actions_t)
                        actor_loss = -torch.mean(value_pred)
                        self.actor[agent_id].actor_optimizer.zero_grad()
                        actor_loss.backward()
                        self.actor[agent_id].actor_optimizer.step()
                        self.actor[agent_id].turn_off_grad()
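                        # refresh this agent's action so later agents in the
                        # order condition on its updated policy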
                        actions[agent_id] = self.actor[agent_id].get_actions(
                            sp_obs[agent_id], False
                        )
                # soft update
                for agent_id in range(self.num_agents):
                    self.actor[agent_id].soft_update()
            self.critic.soft_update()

    def train_MAFIS(self):
        """Train the model"""
        self.total_it += 1
        data = self.buffer.sample()
        (
            sp_share_obs,  # EP: (batch_size, dim), FP: (n_agents * batch_size, dim)
            sp_obs,  # (n_agents, batch_size, dim)
            sp_actions,  # (n_agents, batch_size, dim)
            sp_available_actions,  # (n_agents, batch_size, dim)
            sp_reward,  # EP: (batch_size, 1), FP: (n_agents * batch_size, 1)
            sp_done,  # EP: (batch_size, 1), FP: (n_agents * batch_size, 1)
            sp_valid_transition,  # (n_agents, batch_size, 1)
            sp_term,  # EP: (batch_size, 1), FP: (n_agents * batch_size, 1)
            sp_next_share_obs,  # EP: (batch_size, dim), FP: (n_agents * batch_size, dim)
            sp_next_obs,  # (n_agents, batch_size, dim)
            sp_next_available_actions,  # (n_agents, batch_size, dim)
            sp_gamma,  # EP: (batch_size, 1), FP: (n_agents * batch_size, 1)
        ) = data
        
        demo_data = self.demo_buffer.sample()
        (
            demo_share_obs,  # EP: (batch_size, dim), FP: (n_agents * batch_size, dim)
            demo_obs,  # (n_agents, batch_size, dim)
            demo_actions,  # (n_agents, batch_size, dim)
            demo_available_actions,  # (n_agents, batch_size, dim)
            demo_reward,  # EP: (batch_size, 1), FP: (n_agents * batch_size, 1)
            demo_done,  # EP: (batch_size, 1), FP: (n_agents * batch_size, 1)
            demo_valid_transition,  # (n_agents, batch_size, 1)
            demo_term,  # EP: (batch_size, 1), FP: (n_agents * batch_size, 1)
            demo_next_share_obs,  # EP: (batch_size, dim), FP: (n_agents * batch_size, dim)
            demo_next_obs,  # (n_agents, batch_size, dim)
            demo_next_available_actions,  # (n_agents, batch_size, dim)
            demo_gamma,  # EP: (batch_size, 1), FP: (n_agents * batch_size, 1)
        ) = demo_data
        
        # train critic
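        # Mix demonstration and self-play batches: agent-stacked arrays of
        # shape (n_agents, batch_size, dim) concatenate along the batch axis
        # (axis=1); flat EP/FP-shaped arrays concatenate along their leading
        # batch axis (axis=0).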
        all_obs = np.concatenate([demo_obs, sp_obs], axis=1)
        all_share_obs = np.concatenate([demo_share_obs, sp_share_obs], axis=0)
        all_next_obs = np.concatenate([demo_next_obs, sp_next_obs], axis=1)
        all_next_share_obs = np.concatenate([demo_next_share_obs, sp_next_share_obs], axis=0)
        all_actions = np.concatenate([demo_actions, sp_actions], axis=1)
        all_done = np.concatenate([demo_done, sp_done], axis=0)
        all_term = np.concatenate([demo_term, sp_term], axis=0)
        all_valid_transition = np.concatenate([demo_valid_transition, sp_valid_transition], axis=1)
        all_gamma = np.concatenate([demo_gamma, sp_gamma], axis=0)

        self.critic.turn_on_grad()
        self.critic.train_MAFIS(
            all_obs,
            all_share_obs,
            all_actions,
            all_done,
            all_valid_transition,
            all_term,
            all_next_obs,
            all_next_share_obs,
            all_gamma,
            self.actor,
            self.value_normalizer,
        )
        self.critic.turn_off_grad()
        
        if self.algo_args["train"]["sac"]:
            for agent_id in range(self.num_agents):
                self.actor[agent_id].turn_on_grad()
                # train this agent
                a, logp = self.actor[agent_id].get_actions_with_logprobs(
                    sp_obs[agent_id], None
                )
                value_pred = self.critic.get_values(sp_obs[agent_id], a)
                actor_loss = -torch.mean(value_pred - self.alpha[agent_id] * logp)
                self.actor[agent_id].actor_optimizer.zero_grad()
                actor_loss.backward()
                self.actor[agent_id].actor_optimizer.step()
                self.actor[agent_id].turn_off_grad()
        self.critic.soft_update()
